import os.path as osp
import argparse
from sklearn.neighbors import NearestNeighbors
import torch
import json
from lib import DATASETS


def check_nn_map_file(nn_map_file, real_dataset_image_filenames):
    """Check whether a NN map file exists and covers the expected real images.

    A NN map file is a JSON dict whose keys are real images' filenames; it is
    considered valid only when its key set matches `real_dataset_image_filenames`
    exactly (order-insensitive).

    Args:
        nn_map_file (str): nn map file
        real_dataset_image_filenames (list): list of real images' filenames

    Returns:
        exists (bool): whether the nn map file exists and contains the correct real images' filenames
    """
    if not osp.exists(nn_map_file):
        return False

    with open(nn_map_file) as f:
        stored_map = json.load(f)

    return set(stored_map) == set(real_dataset_image_filenames)


# Feature spaces that may be available, as (file prefix, human-readable label)
# pairs. A space is used only when its `<prefix>_features.pt` file exists for
# BOTH the real and the fake dataset.
FEATURE_SPACES = (('clip', 'CLIP'), ('farl', 'FaRL'), ('dino', 'DINO'), ('arcface', 'ArcFace'))


def _read_filenames(txt_file):
    """Read a text file with one filename per line and return the stripped list.

    Args:
        txt_file (str): text file to read

    Returns:
        list: list of stripped filenames
    """
    with open(txt_file) as f:
        return [line.strip() for line in f]


def _save_nn_map(nn_map_file, real_filenames, fake_filenames, indices):
    """Build the real-to-fake NN map dictionary and save it as a JSON file.

    Args:
        nn_map_file (str): output JSON file
        real_filenames (list): real images' filenames (one per query sample)
        fake_filenames (list): fake images' filenames (the NN pool)
        indices (np.ndarray): per-query NN indices, as returned by
            `NearestNeighbors.kneighbors` with n_neighbors=1
    """
    nn_map = {real_filenames[i]: fake_filenames[int(indices[i])]
              for i in range(len(real_filenames))}
    with open(nn_map_file, "w") as f:
        json.dump(nn_map, f)


def main():
    """A script for finding the Nearest Neighbor (NN) of each sample in a given real dataset from a pool of fake images
    (as generated by `create_fake_dataset.py`). The NNs will be found in all available features spaces (CLIP [1] and/or
    FaRL [2] and/or DINO [3] and/or ArcFace [4]) depending on their availability in the given real and fake datasets.

    For any given real dataset, a file that contains the map between the real images and the fake NNs will be stored
    under the fake dataset's directory.

    Options:
        -v, --verbose       : set verbose mode on
        --real-dataset      : choose a real dataset (see lib/config.py:DATASETS.keys()) -- features for the real dataset
                              should have first been calculated and stored under datasets/<args.real_dataset>_features/
                              by `extract_features.py`
        --fake-dataset-root : set the fake dataset's root directory (as generated by `create_fake_dataset.py` under
                              datasets/)
        --algorithm         : set algorithm used to compute the nearest neighbors ('auto', 'ball_tree', 'kd_tree',
                              'brute')
        --metric            : set metric to use for distance computation
        --cuda              : use CUDA (default)
        --no-cuda           : do not use CUDA

    References:
        [1] Radford, Alec, et al. "Learning transferable visual models from natural language supervision."
            International Conference on Machine Learning. PMLR, 2021.
        [2] Zheng, Yinglin, et al. "General Facial Representation Learning in a Visual-Linguistic Manner."
            Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022.
        [3] Caron, Mathilde, et al. "Emerging properties in self-supervised vision transformers." Proceedings of the
            IEEE/CVF International Conference on Computer Vision. 2021.
        [4] Deng, Jiankang, et al. "ArcFace: Additive angular margin loss for deep face recognition." Proceedings of
            the IEEE/CVF conference on computer vision and pattern recognition. 2019.

    """
    parser = argparse.ArgumentParser(
        description="Pair each image of a given real dataset with an image of a given fake dataset")
    parser.add_argument('-v', '--verbose', action='store_true', help="verbose mode on")
    parser.add_argument('--real-dataset', type=str, required=True, choices=DATASETS.keys(), help="real dataset")
    parser.add_argument('--fake-dataset-root', type=str,
                        help="set the fake dataset's root directory "
                             "(as generated by `create_fake_dataset.py` under datasets/)")
    parser.add_argument('--algorithm', default='all', choices=('auto', 'ball_tree', 'kd_tree', 'brute', 'all'),
                        help="set algorithm used to compute the nearest neighbors")
    parser.add_argument('--metric', type=str, default='all', choices=('euclidean', 'cosine', 'all'),
                        help="metric to use for distance computation")
    parser.add_argument('--cuda', dest='cuda', action='store_true', help="use CUDA during training")
    parser.add_argument('--no-cuda', dest='cuda', action='store_false', help="do NOT use CUDA during training")
    parser.set_defaults(cuda=True)

    # Parse given arguments
    args = parser.parse_args()

    # Expand the 'all' sentinel into the concrete lists of algorithms/metrics.
    NN_ALGORITHMS = [args.algorithm]
    if args.algorithm == 'all':
        NN_ALGORITHMS = ['auto', 'ball_tree', 'kd_tree', 'brute']

    NN_METRICS = [args.metric]
    if args.metric == 'all':
        NN_METRICS = ['euclidean', 'cosine']

    ####################################################################################################################
    ##                                                                                                                ##
    ##                                            [ Real Dataset Features ]                                           ##
    ##                                                                                                                ##
    ####################################################################################################################
    real_dataset_features_dir = osp.join('datasets', 'features', '{}'.format(args.real_dataset))
    if not osp.isdir(real_dataset_features_dir):
        raise NotADirectoryError(
            "Directory of real dataset features ({}) not found -- use `extract_features.py` to create it.".format(
                real_dataset_features_dir))

    if args.verbose:
        # Fix: the original printed args.fake_dataset_root here, which is not
        # what the message describes.
        print("#. Real dataset features root directory: {}".format(real_dataset_features_dir))

    # Get real dataset image filenames
    real_dataset_image_filenames = _read_filenames(osp.join(real_dataset_features_dir, 'image_filenames.txt'))

    if args.verbose:
        print("  \\__real_dataset_image_filenames: {}".format(len(real_dataset_image_filenames)))

    # Load the real features for every available feature space (CLIP / FaRL /
    # DINO / ArcFace); missing feature files simply disable that space.
    real_features = dict()
    for key, label in FEATURE_SPACES:
        real_features_file = osp.join(real_dataset_features_dir, '{}_features.pt'.format(key))
        if osp.exists(real_features_file):
            real_features[key] = torch.load(real_features_file).numpy()
            if args.verbose:
                print("  \\__{} features: {}".format(label, real_features[key].shape))

    ####################################################################################################################
    ##                                                                                                                ##
    ##                                            [ Fake Dataset Features ]                                           ##
    ##                                                                                                                ##
    ####################################################################################################################
    # Loop-invariant: validate the fake dataset root and read its image
    # filenames once, instead of once per (metric, algorithm) combination.
    if args.fake_dataset_root is None or not osp.isdir(args.fake_dataset_root):
        raise NotADirectoryError(
            "Fake dataset root directory ({}) not found -- use `create_fake_dataset.py` to create it.".format(
                args.fake_dataset_root))

    fake_dataset_image_filenames = _read_filenames(osp.join(args.fake_dataset_root, 'latent_code_hashes.txt'))

    if args.verbose:
        print("#. Finding NNs for the following algorithms and metrics:")
        print("  \\__NN algorithms : {}".format(NN_ALGORITHMS))
        print("  \\__NN metrics    : {}".format(NN_METRICS))
        print("#. Process...")

    for nn_metric in NN_METRICS:
        for nn_algorithm in NN_ALGORITHMS:

            print("  \\__.(metric, algorithm) = ({}, {})".format(nn_metric, nn_algorithm))

            # ball_tree / kd_tree do not support the cosine metric.
            if (nn_metric == 'cosine') and (nn_algorithm in ('ball_tree', 'kd_tree')):
                print("      \\__.Invalid combination -- Abort!")
                continue

            if args.verbose:
                print("      \\__.Fake dataset root directory: {}".format(args.fake_dataset_root))
                print("          \\__fake_dataset_image_filenames: {}".format(len(fake_dataset_image_filenames)))

            for key, label in FEATURE_SPACES:
                # Skip feature spaces unavailable in the real dataset.
                if key not in real_features:
                    continue

                # Skip feature spaces unavailable in the fake dataset.
                fake_features_file = osp.join(args.fake_dataset_root, '{}_features.pt'.format(key))
                if not osp.exists(fake_features_file):
                    continue

                nn_map_file = osp.join(args.fake_dataset_root, '{}_{}_{}_nn_map_{}.json'.format(
                    key, nn_algorithm, nn_metric, args.real_dataset))

                # Skip combinations that already have a valid NN map. Fix: the
                # original clobbered the per-feature `use_*` availability flag
                # with this per-combination check, which silently disabled the
                # feature space for every SUBSEQUENT (metric, algorithm)
                # combination; the check is now local to each combination.
                if check_nn_map_file(nn_map_file, real_dataset_image_filenames):
                    continue

                fake_features = torch.load(fake_features_file).numpy()

                if args.verbose:
                    print("          \\__{} features: {}".format(label, fake_features.shape))
                    print("          \\__Fit NN model...", end="")

                nn_model = NearestNeighbors(n_neighbors=1,
                                            algorithm=nn_algorithm,
                                            metric=nn_metric
                                            ).fit(fake_features)
                if args.verbose:
                    print("Done!")
                    print("              \\__Find {} NNs...".format(label), end="")

                # For each real sample, find the index of its 1-NN in the fake pool.
                _, indices = nn_model.kneighbors(real_features[key])

                if args.verbose:
                    print("Done!")

                # Build and save the real-to-fake NN map
                _save_nn_map(nn_map_file, real_dataset_image_filenames, fake_dataset_image_filenames, indices)

# Run the NN-pairing script only when executed directly (not on import).
if __name__ == '__main__':
    main()
